{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jgV5DXxPhU3h"
   },
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/asteroid-team/asteroid/blob/master/notebooks/01_APIOverview.ipynb)\n",
    "\n",
    "### Introduction\n",
    "Asteroid is an open-source, community-based toolkit made to design, train, evaluate, use and share audio source separation models such as Deep clustering ([Hershey et al.](https://arxiv.org/abs/1508.04306)), ConvTasNet ([Luo et al.](https://arxiv.org/abs/1809.07454)) DPRNN ([Luo et al.](https://arxiv.org/abs/1910.06379)) etc..  \n",
    "Along with the models, Asteroid provides building blocks, losses, metrics and datasets commonly used in source separation. This makes it easy to design new source separation models and benchmark them against others ! \n",
    "\n",
    "For training, Asteroid relies of the great [PyTorchLightning](https://github.com/PyTorchLightning/pytorch-lightning), which handles automatic distributed training, logging, experiment resume and much more, be sure to check it out! For the rest, it's native [PyTorch](https://pytorch.org).\n",
    "\n",
    "Enough talking, let's start !"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "J9tFIeeXc9A6"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001B[33mWARNING: You are using pip version 19.3.1; however, version 20.2.2 is available.\r\n",
      "You should consider upgrading via the 'pip install --upgrade pip' command.\u001B[0m\r\n"
     ]
    }
   ],
   "source": [
    "# First off, install asteroid\n",
    "!pip install git+https://github.com/asteroid-team/asteroid --quiet\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### After installing requirements, you need to Restart Runtime (Ctrl + M).\n",
    "\n",
    "Else it will fail to import asteroid"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "NEMccxSyO7wd"
   },
   "source": [
    "### Waveform transformations & features\n",
    "Time-frequency transformations are often performed on waveforms before feeding them to source separation models. Most of them can be formulated as convolutions with specific (learned or not) filterbank. Their inverses, mapping back to time domain, can be formulated as transposed convolution. \n",
    "Asteroid proposes a unified view of this transformations, which is implemented with the classes `Filterbank`, `Encoder` and `Decoder`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "3alvBplu8ckU"
   },
   "source": [
    "The `Filterbank` object is the one holding the actual filters that are used to compute the transforms. `Encoder` and `Decoder` are applied on top to provide method to go back and forth from waveform to time-frequency domain.\n",
    "\n",
    "A common example is the one of the STFT, that can be defined as follows:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "3Kfof5Ik9M7y"
   },
   "outputs": [],
   "source": [
    "from asteroid_filterbanks import STFTFB, Encoder, Decoder\n",
    "# First, instantiate the STFT filterbank\n",
    "fb = STFTFB(n_filters=256, kernel_size=128, stride=64)\n",
    "# Make an encoder out of it, forward some waveform through it.\n",
    "encoder = Encoder(fb)\n",
    "# Same for decoder\n",
    "decoder_fb = STFTFB(n_filters=256, kernel_size=128, stride=64)\n",
    "decoder = Decoder(decoder_fb)\n",
    "\n",
    "# The preceding lines can also be obtained faster with these lines\n",
    "from asteroid_filterbanks import make_enc_dec\n",
    "encoder, decoder = make_enc_dec('stft', n_filters=256, \n",
    "                                kernel_size=128, stride=64)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "mWN27rzg-DLi"
   },
   "source": [
    "From there, the interface of `Encoder` is the same as the one from `torch.nn.Conv1d` and `Decoder` as `torch.nn.ConvTranspose1d`, and a waveform-like object can be transformed like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "DLSk_QDK-C24"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "# Waveform-like\n",
    "wav = torch.randn(2, 1, 16000)\n",
    "# Time-frequency representation\n",
    "tf_rep = encoder(wav)\n",
    "# Back to time domain\n",
    "wav_back = decoder(tf_rep)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "McWJACMQ-6Eo"
   },
   "source": [
    "More info on automatic pseudo-inverse, how to define your own filterbanks etc.. can be found in the\n",
    "[Filterbank notebook](https://github.com/asteroid-team/asteroid/blob/master/notebooks/02_Filterbank.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Ni8S1Owm_J32"
   },
   "source": [
    "### Masker network & Separation models \n",
    "Asteroid aims at providing most state-of-the-art masker neural network. \n",
    "Some of these masking networks and/or separation models share building blocks such as residual LSTMs or D-Conv-based convolutional blocks. \n",
    "Asteroid provides these building blocks as well as common masker networks with building blocks already assembled (eg. `TDConvNet` or `DPRNN`)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "gGSI1fn6AqkZ"
   },
   "source": [
    "These blocks are already configured optimally according to the corresponding papers, just import them and run ! "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "5bfGL-Q-Ap4v"
   },
   "outputs": [],
   "source": [
    "from asteroid.masknn import TDConvNet\n",
    "# We only need to specify the number of input channels\n",
    "# and the number of sources we want to estimate.\n",
    "masker = TDConvNet(in_chan=128, n_src=2)\n",
    "\n",
    "# Now, we can use it to estimate some masks!\n",
    "tf_rep = torch.randn(2, 128, 10)\n",
    "wav_back = masker(tf_rep)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "JkDMqRgIDgY8"
   },
   "source": [
    "Let's put the encoder, masker and decoder together in an `nn.Module` to make it all simple."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "E-sHcI4rDyf0"
   },
   "outputs": [],
   "source": [
    "from asteroid_filterbanks import make_enc_dec\n",
    "\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Encoder and Decode in \"one line\"\n",
    "        self.enc, self.dec = make_enc_dec(\n",
    "            'stft', n_filters=256, kernel_size=128, stride=64\n",
    "            )\n",
    "        # # Mask network from ConvTasNet in one line.\n",
    "        self.masker = TDConvNet(in_chan=self.enc.n_feats_out, \n",
    "                                n_src=2)\n",
    "    \n",
    "    def forward(self, wav):\n",
    "        # Simplified forward\n",
    "        tf_rep = self.enc(wav)\n",
    "        masks = self.masker(tf_rep)\n",
    "        wavs_out = self.dec(tf_rep.unsqueeze(1) * masks)\n",
    "        return wavs_out\n",
    "\n",
    "\n",
    "# Define and forward \n",
    "stft_conv_tasnet = Model()\n",
    "wav_out = stft_conv_tasnet(torch.randn(1, 1, 16000))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "x-pwTAfbFHMZ"
   },
   "source": [
    "Actually, for models like ConvTasNet, they can directly be imported and used from asteroid like this :\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "6n3temjFIV5H"
   },
   "outputs": [],
   "source": [
    "from asteroid import ConvTasNet\n",
    "model = ConvTasNet(n_src=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "_icj67ZAKpFS"
   },
   "source": [
    "### Datasets and DataLoader\n",
    "We support several source separation datasets, you can find more information on them in the docs. \n",
    "Note that their is no common API between them, preparing the data in the format expected by the `Dataset` is the role of the recipes.\n",
    "\n",
    "In order to experiment easily, we added a small part of LibriMix for direct download."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "mCkcZ7aXOSOh"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Drop 0 utterances from 800 (shorter than 3 seconds)\n",
      "Drop 0 utterances from 200 (shorter than 3 seconds)\n"
     ]
    }
   ],
   "source": [
    "from asteroid.data import LibriMix\n",
    "\n",
    "train_set, val_set = LibriMix.mini_from_download(task='sep_clean')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RwA632V3XaBV"
   },
   "source": [
    "### Loss functions\n",
    "Asteroid provides several loss functions that are commonly used for source separation or speech enhancement. More importantly, we also provide `PITLossWrapper`, an efficient wrapper that can turn any loss function into a permutation invariant loss. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "O5NijgIOX1CZ"
   },
   "source": [
    "For example, defining a permuatation invariant si-sdr loss, run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "vg8eZXmiYAEN"
   },
   "outputs": [],
   "source": [
    "from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr\n",
    "\n",
    "\n",
    "loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "TzQRTk6QYKd_"
   },
   "source": [
    "You can find more info about this in the [PIT loss tutorial](https://github.com/asteroid-team/asteroid/blob/master/notebooks/03_PITLossWrapper.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "2gPbCAjGIJNC"
   },
   "source": [
    "### Training\n",
    "\n",
    "For training, Asteroid relies on PyTorchLightning which automatizes almost everything for us. We have a thin wrapper around it to make things even simpler.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "D2XyR6K5OBvS"
   },
   "source": [
    "#### Putting all ingredients together with `System`\n",
    "To use PyTorchLightning, we need to define all the ingredients (dataloaders, model, loss functions, optimizers, etc..) into one object, the `LightningModule`. In order to keep things separate and re-usable, and to reduce boilerplate, we define a sub-class, `System`, which expects these ingredients separately. \n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "fE7m8VunUmM4"
   },
   "source": [
    "Additionally, `LightningModule` needs to expose the `training_step` and `validation_step` functions. It is usual for these functions to be shared or really similar so we grouped them under `common_step`.\n",
    "```\n",
    "class System(pl.LightningModule):\n",
    "    def __init__(self, model, optimizer, loss_func, train_loader,\n",
    "                 val_loader=None, scheduler=None, config=None):\n",
    "      ...\n",
    "\n",
    "    def common_step(self, batch, batch_nb, train=True):\n",
    "        inputs, targets = batch\n",
    "        est_targets = self(inputs)\n",
    "        loss = self.loss_func(est_targets, targets)\n",
    "        return loss\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "iBuioyXlXGmU"
   },
   "source": [
    "#### Example training script"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "P0V-TatleOUi"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:lightning:Running in fast_dev_run mode: will run a full train, val and test loop using a single batch\n",
      "INFO:lightning:GPU available: False, used: False\n",
      "INFO:lightning:\n",
      "   | Name                              | Type           | Params\n",
      "-----------------------------------------------------------------\n",
      "0  | model                             | ConvTasNet     | 1 M   \n",
      "1  | model.encoder                     | Encoder        | 8 K   \n",
      "2  | model.encoder.filterbank          | FreeFB         | 8 K   \n",
      "3  | model.masker                      | TDConvNet      | 1 M   \n",
      "4  | model.masker.bottleneck           | Sequential     | 66 K  \n",
      "5  | model.masker.bottleneck.0         | GlobLN         | 1 K   \n",
      "6  | model.masker.bottleneck.1         | Conv1d         | 65 K  \n",
      "7  | model.masker.TCN                  | ModuleList     | 1 M   \n",
      "8  | model.masker.TCN.0                | Conv1DBlock    | 201 K \n",
      "9  | model.masker.TCN.0.shared_block   | Sequential     | 70 K  \n",
      "10 | model.masker.TCN.0.shared_block.0 | Conv1d         | 66 K  \n",
      "11 | model.masker.TCN.0.shared_block.1 | PReLU          | 1     \n",
      "12 | model.masker.TCN.0.shared_block.2 | GlobLN         | 1 K   \n",
      "13 | model.masker.TCN.0.shared_block.3 | Conv1d         | 2 K   \n",
      "14 | model.masker.TCN.0.shared_block.4 | PReLU          | 1     \n",
      "15 | model.masker.TCN.0.shared_block.5 | GlobLN         | 1 K   \n",
      "16 | model.masker.TCN.0.res_conv       | Conv1d         | 65 K  \n",
      "17 | model.masker.TCN.0.skip_conv      | Conv1d         | 65 K  \n",
      "18 | model.masker.TCN.1                | Conv1DBlock    | 201 K \n",
      "19 | model.masker.TCN.1.shared_block   | Sequential     | 70 K  \n",
      "20 | model.masker.TCN.1.shared_block.0 | Conv1d         | 66 K  \n",
      "21 | model.masker.TCN.1.shared_block.1 | PReLU          | 1     \n",
      "22 | model.masker.TCN.1.shared_block.2 | GlobLN         | 1 K   \n",
      "23 | model.masker.TCN.1.shared_block.3 | Conv1d         | 2 K   \n",
      "24 | model.masker.TCN.1.shared_block.4 | PReLU          | 1     \n",
      "25 | model.masker.TCN.1.shared_block.5 | GlobLN         | 1 K   \n",
      "26 | model.masker.TCN.1.res_conv       | Conv1d         | 65 K  \n",
      "27 | model.masker.TCN.1.skip_conv      | Conv1d         | 65 K  \n",
      "28 | model.masker.TCN.2                | Conv1DBlock    | 201 K \n",
      "29 | model.masker.TCN.2.shared_block   | Sequential     | 70 K  \n",
      "30 | model.masker.TCN.2.shared_block.0 | Conv1d         | 66 K  \n",
      "31 | model.masker.TCN.2.shared_block.1 | PReLU          | 1     \n",
      "32 | model.masker.TCN.2.shared_block.2 | GlobLN         | 1 K   \n",
      "33 | model.masker.TCN.2.shared_block.3 | Conv1d         | 2 K   \n",
      "34 | model.masker.TCN.2.shared_block.4 | PReLU          | 1     \n",
      "35 | model.masker.TCN.2.shared_block.5 | GlobLN         | 1 K   \n",
      "36 | model.masker.TCN.2.res_conv       | Conv1d         | 65 K  \n",
      "37 | model.masker.TCN.2.skip_conv      | Conv1d         | 65 K  \n",
      "38 | model.masker.TCN.3                | Conv1DBlock    | 201 K \n",
      "39 | model.masker.TCN.3.shared_block   | Sequential     | 70 K  \n",
      "40 | model.masker.TCN.3.shared_block.0 | Conv1d         | 66 K  \n",
      "41 | model.masker.TCN.3.shared_block.1 | PReLU          | 1     \n",
      "42 | model.masker.TCN.3.shared_block.2 | GlobLN         | 1 K   \n",
      "43 | model.masker.TCN.3.shared_block.3 | Conv1d         | 2 K   \n",
      "44 | model.masker.TCN.3.shared_block.4 | PReLU          | 1     \n",
      "45 | model.masker.TCN.3.shared_block.5 | GlobLN         | 1 K   \n",
      "46 | model.masker.TCN.3.res_conv       | Conv1d         | 65 K  \n",
      "47 | model.masker.TCN.3.skip_conv      | Conv1d         | 65 K  \n",
      "48 | model.masker.TCN.4                | Conv1DBlock    | 201 K \n",
      "49 | model.masker.TCN.4.shared_block   | Sequential     | 70 K  \n",
      "50 | model.masker.TCN.4.shared_block.0 | Conv1d         | 66 K  \n",
      "51 | model.masker.TCN.4.shared_block.1 | PReLU          | 1     \n",
      "52 | model.masker.TCN.4.shared_block.2 | GlobLN         | 1 K   \n",
      "53 | model.masker.TCN.4.shared_block.3 | Conv1d         | 2 K   \n",
      "54 | model.masker.TCN.4.shared_block.4 | PReLU          | 1     \n",
      "55 | model.masker.TCN.4.shared_block.5 | GlobLN         | 1 K   \n",
      "56 | model.masker.TCN.4.res_conv       | Conv1d         | 65 K  \n",
      "57 | model.masker.TCN.4.skip_conv      | Conv1d         | 65 K  \n",
      "58 | model.masker.TCN.5                | Conv1DBlock    | 201 K \n",
      "59 | model.masker.TCN.5.shared_block   | Sequential     | 70 K  \n",
      "60 | model.masker.TCN.5.shared_block.0 | Conv1d         | 66 K  \n",
      "61 | model.masker.TCN.5.shared_block.1 | PReLU          | 1     \n",
      "62 | model.masker.TCN.5.shared_block.2 | GlobLN         | 1 K   \n",
      "63 | model.masker.TCN.5.shared_block.3 | Conv1d         | 2 K   \n",
      "64 | model.masker.TCN.5.shared_block.4 | PReLU          | 1     \n",
      "65 | model.masker.TCN.5.shared_block.5 | GlobLN         | 1 K   \n",
      "66 | model.masker.TCN.5.res_conv       | Conv1d         | 65 K  \n",
      "67 | model.masker.TCN.5.skip_conv      | Conv1d         | 65 K  \n",
      "68 | model.masker.TCN.6                | Conv1DBlock    | 201 K \n",
      "69 | model.masker.TCN.6.shared_block   | Sequential     | 70 K  \n",
      "70 | model.masker.TCN.6.shared_block.0 | Conv1d         | 66 K  \n",
      "71 | model.masker.TCN.6.shared_block.1 | PReLU          | 1     \n",
      "72 | model.masker.TCN.6.shared_block.2 | GlobLN         | 1 K   \n",
      "73 | model.masker.TCN.6.shared_block.3 | Conv1d         | 2 K   \n",
      "74 | model.masker.TCN.6.shared_block.4 | PReLU          | 1     \n",
      "75 | model.masker.TCN.6.shared_block.5 | GlobLN         | 1 K   \n",
      "76 | model.masker.TCN.6.res_conv       | Conv1d         | 65 K  \n",
      "77 | model.masker.TCN.6.skip_conv      | Conv1d         | 65 K  \n",
      "78 | model.masker.TCN.7                | Conv1DBlock    | 201 K \n",
      "79 | model.masker.TCN.7.shared_block   | Sequential     | 70 K  \n",
      "80 | model.masker.TCN.7.shared_block.0 | Conv1d         | 66 K  \n",
      "81 | model.masker.TCN.7.shared_block.1 | PReLU          | 1     \n",
      "82 | model.masker.TCN.7.shared_block.2 | GlobLN         | 1 K   \n",
      "83 | model.masker.TCN.7.shared_block.3 | Conv1d         | 2 K   \n",
      "84 | model.masker.TCN.7.shared_block.4 | PReLU          | 1     \n",
      "85 | model.masker.TCN.7.shared_block.5 | GlobLN         | 1 K   \n",
      "86 | model.masker.TCN.7.res_conv       | Conv1d         | 65 K  \n",
      "87 | model.masker.TCN.7.skip_conv      | Conv1d         | 65 K  \n",
      "88 | model.masker.mask_net             | Sequential     | 132 K \n",
      "89 | model.masker.mask_net.0           | PReLU          | 1     \n",
      "90 | model.masker.mask_net.1           | Conv1d         | 132 K \n",
      "91 | model.masker.output_act           | Sigmoid        | 0     \n",
      "92 | model.decoder                     | Decoder        | 8 K   \n",
      "93 | model.decoder.filterbank          | FreeFB         | 8 K   \n",
      "94 | model.enc_activation              | Identity       | 0     \n",
      "95 | loss_func                         | PITLossWrapper | 0     \n",
      "96 | loss_func.loss_func               | PairwiseNegSDR | 0     \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Drop 0 utterances from 800 (shorter than 3 seconds)\n",
      "Drop 0 utterances from 200 (shorter than 3 seconds)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from torch.optim import Adam\n",
    "from torch.utils.data import DataLoader\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "from asteroid.data import LibriMix\n",
    "from asteroid.engine.system import System\n",
    "from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr\n",
    "from asteroid import ConvTasNet\n",
    "\n",
    "train_set, val_set = LibriMix.mini_from_download(task='sep_clean')\n",
    "train_loader = DataLoader(train_set, batch_size=4, drop_last=True)\n",
    "val_loader = DataLoader(val_set, batch_size=4, drop_last=True)\n",
    "\n",
    "# Define model and optimizer (one repeat to be faster)\n",
    "model = ConvTasNet(n_src=2, n_repeats=1)\n",
    "optimizer = Adam(model.parameters(), lr=1e-3)\n",
    "# Define Loss function.\n",
    "loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')\n",
    "# Define System\n",
    "system = System(model=model, loss_func=loss_func, optimizer=optimizer,\n",
    "                train_loader=train_loader, val_loader=val_loader)\n",
    "# Define lightning trainer, and train\n",
    "trainer = pl.Trainer(fast_dev_run=True)\n",
    "trainer.fit(system)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "0v76oPR2NnUO"
   },
   "source": [
    "#### Extending `System`\n",
    "If your model or data is a bit different, changing `System` is easy, just overwrite the `common_step` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "kRzawfhvTNVH"
   },
   "outputs": [],
   "source": [
    "# Example of how simple it is to define a new System with \n",
    "# different training dynamic.\n",
    "class YourSystem(System):\n",
    "    def common_step(self, batch, batch_nb, train=True):\n",
    "        # Your DataLoader returns three tensors\n",
    "        inputs, some_other_input, targets = batch\n",
    "        # Your model returns two.\n",
    "        est_targets, some_other_output = self(inputs, some_other_input)\n",
    "        if train:\n",
    "            # Your loss takes three argument\n",
    "            loss = self.loss_func(est_targets, targets, cond=some_other_output)\n",
    "        else:\n",
    "            # At validation time, you don't want cond \n",
    "            loss = self.loss_func(est_targets, targets)\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "TDE_DnmfTX_U"
   },
   "source": [
    "Of course, Asteroid is not limited to using `System` as this is pure PyTorchLightning and more complicated use cases might not benefit from `System`. In this case, writing a `LightningModule` would be the way to go !"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "jgV5DXxPhU3h",
    "NEMccxSyO7wd",
    "Ni8S1Owm_J32",
    "D2XyR6K5OBvS",
    "0v76oPR2NnUO"
   ],
   "name": "01-APIOverview.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
